from tabulate import tabulate

# Column headers for the `dims` table, one per tuple position, grouped here
# by the prefix they share. Order matters: it must match the row layout.
names = (
    "L1_B", "L1_K", "L1_M", "L1_N",
    "L2_B", "L2_K", "L2_M", "L2_N",
    "L2_wgt_B", "L2_wgt_K",
    "HEADS",
    "L3_K", "L3_M", "L3_N",
)
# One row per configuration. Each tuple holds exactly one integer per entry
# in `names` (14 columns, same order), which is what lets `tabulate` label
# the columns when the table is printed.
# Observed in every row below (not enforced anywhere):
#   - L1_K == L2_wgt_K, L1_M == L2_M, and L2_K == L3_K;
#   - L1_B, L2_B, L2_wgt_B and HEADS are all 1;
#   - L1_N is always 32 and L2_N is always 128.
# L3_M is written as `<n> * 2`. NOTE(review): the meaning of the factor of
# 2 is not stated here -- confirm with whoever produced these shapes.
dims = [
    (1, 64, 16, 32, 1, 2560, 16, 128, 1, 64, 1, 2560, 1024 * 2, 640),
    (1, 64, 64, 32, 1, 1280, 64, 128, 1, 64, 1, 1280, 64 * 2, 10240),
    (1, 64, 64, 32, 1, 1280, 64, 128, 1, 64, 1, 1280, 64 * 2, 1280),
    (1, 64, 16, 32, 1, 5120, 16, 128, 1, 64, 1, 5120, 64 * 2, 1280),
    (1, 32, 80, 32, 1, 768, 80, 128, 1, 32, 1, 768, 80 * 2, 1280),
    (1, 32, 80, 32, 1, 768, 80, 128, 1, 32, 1, 768, 80 * 2, 320),
    (1, 32, 80, 32, 1, 768, 80, 128, 1, 32, 1, 768, 80 * 2, 640),
    (1, 64, 16, 32, 1, 640, 16, 128, 1, 64, 1, 640, 1024 * 2, 5120),
    (1, 64, 16, 32, 1, 640, 16, 128, 1, 64, 1, 640, 1024 * 2, 640),
    (1, 64, 16, 32, 1, 1280, 16, 128, 1, 64, 1, 1280, 256 * 2, 10240),
    (1, 64, 16, 32, 1, 1280, 16, 128, 1, 64, 1, 1280, 256 * 2, 1280),
    (1, 64, 16, 32, 1, 5120, 16, 128, 1, 64, 1, 5120, 256 * 2, 1280),
    (1, 64, 16, 32, 1, 1280, 16, 128, 1, 64, 1, 1280, 4096 * 2, 320),
    (1, 64, 16, 32, 1, 320, 16, 128, 1, 64, 1, 320, 4096 * 2, 2560),
    (1, 64, 16, 32, 1, 320, 16, 128, 1, 64, 1, 320, 4096 * 2, 320),
    (1, 64, 16, 32, 1, 1280, 16, 128, 1, 64, 1, 1280, 2, 1280),
    (1, 64, 16, 32, 1, 1280, 16, 128, 1, 64, 1, 1280, 2, 320),
    (1, 64, 16, 32, 1, 1280, 16, 128, 1, 64, 1, 1280, 2, 640),
    (1, 64, 16, 32, 1, 320, 16, 128, 1, 64, 1, 320, 2, 1280),
]

# Render every row of `dims` as a plain-text table labelled by `names`.
print(tabulate(tabular_data=dims, headers=names))


# NOTE(review): `a` is not referenced by the code visible in this file and
# its name says nothing about its purpose -- confirm whether it is consumed
# elsewhere before renaming or removing it.
# Each row holds 5 integers. Observed in every row (not enforced):
#   row[0] == row[1] * row[2]  and  row[3] == row[0].
# TODO: document what each column (including row[4]) represents.
a = [
    (5120, 40, 128, 5120, 64),
    (400, 20, 20, 400, 1),
    (60, 20, 3, 60, 1),
    (960, 80, 12, 960, 4),
    (72, 24, 3, 72, 1),
    (24, 24, 1, 24, 1),
    (48, 24, 2, 48, 1),
    (6400, 10, 640, 6400, 64),
    (1280, 10, 128, 1280, 64),
    (6400, 20, 320, 6400, 16),
    (960, 20, 48, 960, 16),
    (3840, 80, 48, 3840, 16),
    (5120, 20, 256, 5120, 256),
    (6400, 5, 1280, 6400, 256),
    (1280, 5, 256, 1280, 256),
]
